package cn.langpy.simforkjoin.core;


import cn.langpy.simforkjoin.annotation.ForkJoin;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.*;
import java.util.logging.Logger;


/**
 * Spring AOP aspect that intercepts methods annotated with {@link ForkJoin} and, for large
 * inputs, runs them as several tasks on an executor obtained from {@link ContextTask}.
 *
 * <p>The first argument of an intercepted method is treated as a {@link List}. When it has fewer
 * than {@link ForkJoin#threshold()} elements the method is simply invoked directly; otherwise
 * {@link ContextTask} submits the work to an {@link ExecutorCompletionService} and the
 * {@code List} results of the submitted tasks are concatenated.
 */
@Aspect
@Component
public class ForkJoinHandler {
    // getName() (not toString(), which yields "class cn.langpy...") so the logger sits in the
    // normal package hierarchy and picks up package-level logging configuration.
    public static Logger log = Logger.getLogger(ForkJoinHandler.class.getName());

    @Autowired
    ContextTask contextTask;

    /** Pointcut matching every method annotated with {@link ForkJoin}. */
    @Pointcut("@annotation(cn.langpy.simforkjoin.annotation.ForkJoin)")
    public void preProcess() {

    }

    /**
     * Around advice for {@link ForkJoin}-annotated methods.
     *
     * <p>Inputs smaller than the annotation's threshold are passed straight through to the target
     * method. Larger inputs are dispatched through {@link ContextTask#execute}; its return value is
     * taken as the number of task results to wait for. Non-null {@code List} results are appended
     * to a single list in <em>completion order</em> (as returned by
     * {@link CompletionService#take()}), which is not necessarily the order of the input.
     *
     * <p>The executor obtained for the call is always handed back to
     * {@link ContextTask#closeThreadPool}, including when a task fails or the wait is interrupted.
     *
     * @param joinPoint the intercepted invocation; presumably {@code contextTask.validate} ensures
     *     {@code args[0]} is a {@code List} — TODO confirm
     * @return the target's own result for small inputs; otherwise the concatenated results when
     *     {@link ForkJoin#isReturn()} is true, else {@code null}
     * @throws java.util.concurrent.ExecutionException if a submitted task threw (the original
     *     exception is the cause)
     * @throws InterruptedException if interrupted while waiting for a task result
     * @throws Throwable anything thrown by the direct {@code proceed} call or by {@link ContextTask}
     */
    @Around("preProcess()")
    public Object before(ProceedingJoinPoint joinPoint) throws Throwable {
        Object[] args = joinPoint.getArgs();
        ForkJoin forkJoin = ((MethodSignature) joinPoint.getSignature()).getMethod().getAnnotation(ForkJoin.class);
        contextTask.validate(forkJoin, args);
        int threshold = forkJoin.threshold();
        List<?> input = (List<?>) args[0];
        if (input.size() < threshold) {
            // Not worth splitting: run the original method on the caller's thread.
            return joinPoint.proceed(args);
        }
        ExecutorService threadExecutor = contextTask.getExecutor(forkJoin);
        try {
            // Kept raw to stay compatible with whatever element type ContextTask.execute declares.
            @SuppressWarnings({"rawtypes", "unchecked"})
            CompletionService completionService = new ExecutorCompletionService(threadExecutor);
            int n = contextTask.execute(completionService, joinPoint, args, threshold);
            List<Object> results = new ArrayList<>();
            for (int i = 0; i < n; i++) {
                List<?> result = (List<?>) completionService.take().get();
                if (result != null) {
                    results.addAll(result);
                }
            }
            return forkJoin.isReturn() ? results : null;
        } finally {
            // Previously skipped on any failure, leaking the executor.
            contextTask.closeThreadPool(threadExecutor);
        }
    }


}
